import torch
import torch.nn as nn
import torch.nn.functional as F


class MainLoss(nn.Module):
    """Pixel-wise alignment loss evaluated on the scale-128 maps.

    ``forward`` reads these entries of the model ``output`` dict:

    * ``'ali_score_map'``: predicted probabilities (squeezed on dim 1).
    * ``'fg_score_map'``: foreground scores (squeezed on dim 1).
    * ``'fine_gt_mask'[128]``: ground-truth mask for the scale-128 map.
    * ``'mask'[128]``: validity mask; pixels with ``mask == 0`` are ignored.

    NOTE(review): the score maps are presumably ``(B, 1, H, W)`` and the masks
    ``(B, H, W)`` (inferred from the ``squeeze(1)`` calls and the
    ``bsz, h, w`` unpacking) -- confirm against the model.
    """

    def __init__(self):
        super().__init__()

    def forward(self, output):
        """Compute the training loss from the model ``output`` dict.

        Returns:
            ``(final_loss, loss_dict, None)``. ``final_loss`` is the
            differentiable pixel-wise loss. ``loss_dict`` holds its Python float
            value under ``'pw_loss'`` for logging.
        """
        # The reg/ps terms are returned by cal_pixel_wise_loss but are
        # currently not part of the objective.
        pw_loss, _reg_loss, _ps_loss = self.cal_pixel_wise_loss(
            output['ali_score_map'].squeeze(1),
            output['fine_gt_mask'][128],
            output['mask'][128],
            output['fg_score_map'].squeeze(1),
        )
        final_loss = pw_loss
        loss_dict = {'pw_loss': pw_loss.item()}
        return final_loss, loss_dict, None

    @staticmethod
    def cal_pixel_wise_loss(score, gt, mask, fg_score_map):
        """Masked binary cross-entropy with positives up-weighted by 1.5.

        The per-pixel loss is summed over the last two dims of the valid
        (``mask != 0``) pixels, divided by the number of valid pixels of that
        image, and then averaged over the batch.

        Args:
            score: predicted probabilities in ``[0, 1]``, e.g. ``(B, H, W)``.
            gt: binary ground truth, same shape as ``score``.
            mask: validity mask; pixels where ``mask == 0`` are ignored.
            fg_score_map: unused; kept so the signature matches
                ``cal_pixel_wise_loss2``.

        Returns:
            ``(loss, reg_loss, ps_loss)``. ``loss`` is a scalar tensor.
            ``reg_loss`` and ``ps_loss`` are the placeholder value ``0.0``.
        """
        num_pix = mask.sum(dim=-1).sum(dim=-1)
        # 1e-10 keeps log() finite when score is exactly 0 or 1.
        bce = -(1.5 * gt * torch.log(score + 1e-10)
                + (1 - gt) * torch.log(1 - score + 1e-10))
        per_image = bce.masked_fill(mask == 0, 0).sum(dim=-1).sum(dim=-1) / num_pix
        return per_image.mean(), 0.0, 0.0

    @staticmethod
    def cal_pixel_wise_loss2(score, gt, mask, fg_score_map):
        """Soft-target positive loss with suppression of the weakest positives.

        For each image, the positive pixels (``gt == 1``) are ranked by
        ``fg_score_map``. The lowest-ranked ones, from index
        ``int(0.9 * num_pos)`` onwards, are treated as suspicious:

        * they are pushed towards background with ``-log(1 - score)``;
        * they are removed from the softmax target used for the positive term.

        Args:
            score: predicted probabilities in ``[0, 1]``, shape ``(B, H, W)``.
            gt: binary ground truth, same shape as ``score``.
            mask: validity mask; pixels where ``mask == 0`` are ignored by the
                negative term and excluded from the pixel count.
            fg_score_map: ranking scores, shape ``(B, H, W)``.

        Returns:
            ``(loss, reg_loss, ps_loss)``:

            * ``loss`` is the batch-mean supervised loss.
            * ``reg_loss`` is ``0.0``.
            * ``ps_loss`` is the mean suppression term over the images that had
              at least one suppressed pixel. It is a zero scalar tensor when no
              image did; previously this case raised in ``torch.stack``.
        """
        bsz, h, w = fg_score_map.size()
        ratio = 0.9
        num_pix = mask.sum(dim=-1).sum(dim=-1)

        flat_gt = gt.reshape(bsz, h * w)
        flat_score = score.reshape(bsz, h * w)
        # Non-positive pixels get -1e30 so they sort last and get ~0 softmax mass.
        logits = fg_score_map.reshape(bsz, h * w).masked_fill(flat_gt == 0, float('-1e30'))
        num_pos = (flat_gt == 1).sum(dim=-1)
        order = torch.argsort(logits, dim=-1, descending=True)

        ps_terms = []
        for i in range(bsz):
            start, end = int(ratio * num_pos[i]), int(num_pos[i])
            if start == end:
                # Nothing to suppress for this image.
                continue
            # NOTE(review): with num_pos == 1, start == 0, so the only positive
            # is suppressed and the softmax target below becomes uniform.
            hard = order[i, start:end]
            ps_terms.append(-torch.log(1 - flat_score[i, hard] + 1e-10).mean())
            logits[i, hard] = float('-1e30')

        if ps_terms:
            ps_loss = torch.stack(ps_terms, 0).mean()
        else:
            ps_loss = flat_score.new_zeros(())

        target = F.softmax(logits, dim=-1).reshape(bsz, h, w)
        pos_loss = (-1.0 * target * torch.log(score + 1e-10)).sum(dim=-1).sum(dim=-1)
        neg_loss = (-(1 - gt) * torch.log(1 - score + 1e-10)).masked_fill(mask == 0, 0).sum(dim=-1).sum(dim=-1)
        # The softmax target sums to 1, so scale it by the positive count.
        supervised_loss = (pos_loss * num_pos + neg_loss) / num_pix

        return supervised_loss.mean(), 0.0, ps_loss

class Sum_image_loss(nn.Module):
    """Placeholder loss module with no implementation yet.

    ``forward`` accepts no inputs and always returns ``None``.
    """

    def __init__(self):
        super().__init__()

    def forward(self):
        """Return ``None``; this loss is not implemented."""
        return None


class NormalSupervisedLoss(nn.Module):
    """Pixel-wise supervised loss over the alignment score maps.

    Two modes, selected by ``only_full``:

    * ``only_full=True`` (default; the previous hard-coded behaviour): every
      entry ``k`` of ``output['ali_score_map']`` gets plain BCE against
      ``output['fine_gt_mask'][k]``. The per-scale losses are averaged.
    * ``only_full=False``: only the key ``256`` is supervised, using
      :meth:`cal_loss2` against ``output['coarse_gt_mask'][256]``. A
      contrastive term ``-log(output['contrast_score'])`` is also added.
    """

    def __init__(self, only_full=True):
        super().__init__()
        self.only_full = only_full

    @staticmethod
    def cal_loss(score, gt, mask, score_map=None):
        """Return the unreduced element-wise binary cross-entropy.

        ``mask`` and ``score_map`` are accepted for signature compatibility but
        are not used; the caller is responsible for any masking or reduction.
        The ``1e-10`` keeps ``log`` finite at probabilities 0 and 1.
        """
        supervised_loss = -(1.0 * gt * torch.log(score + 1e-10)
                            + (1 - gt) * torch.log(1 - score + 1e-10))
        return supervised_loss

    @staticmethod
    def cal_loss2(score, gt, mask, fg_score_map, fine_gt_mask=None):
        """Per-image soft-target loss with suppression of the weakest positives.

        For each image, the positives (``gt == 1``) are ranked by
        ``fg_score_map``. Those ranked from ``int(0.9 * num_pos)`` onwards are:

        * pushed towards background with ``-log(1 - score)``, weighted by 0.5;
        * removed from the softmax target of the positive term (weight 1.5).

        Args:
            score: predicted probabilities in ``[0, 1]``, shape ``(B, H, W)``.
            gt: binary ground truth, same shape.
            mask: validity mask; pixels with ``mask == 0`` are ignored by the
                negative term and excluded from the pixel count.
            fg_score_map: ranking scores, shape ``(B, H, W)``.
            fine_gt_mask: unused; kept for interface compatibility.

        Returns:
            A ``(B,)`` tensor of per-image losses. Images with nothing to
            suppress contribute a zero suppression term. Previously, skipping
            any image made the suppression vector shorter than ``B``, which
            caused a shape error or a wrong broadcast.
        """
        bsz, h, w = fg_score_map.size()
        ratio = 0.9
        num_pix = mask.sum(dim=-1).sum(dim=-1)

        flat_gt = gt.reshape(bsz, h * w)
        flat_score = score.reshape(bsz, h * w)
        # Non-positive pixels get -1e30 so they sort last and get ~0 softmax mass.
        logits = fg_score_map.reshape(bsz, h * w).masked_fill(flat_gt == 0, float('-1e30'))
        num_pos = (flat_gt == 1).sum(dim=-1)
        order = torch.argsort(logits, dim=-1, descending=True)

        # One entry per image so it lines up with supervised_loss below.
        ps_loss = flat_score.new_zeros(bsz)
        for i in range(bsz):
            start, end = int(ratio * num_pos[i]), int(num_pos[i])
            if start == end:
                # Nothing to suppress for this image.
                continue
            # NOTE(review): with num_pos == 1, start == 0, so the only positive
            # is suppressed and the softmax target below becomes uniform.
            hard = order[i, start:end]
            ps_loss[i] = -torch.log(1 - flat_score[i, hard] + 1e-10).mean()
            logits[i, hard] = float('-1e30')

        target = F.softmax(logits, dim=-1).reshape(bsz, h, w)
        pos_loss = (-1.5 * target * torch.log(score + 1e-10)).sum(dim=-1).sum(dim=-1)
        neg_loss = (-(1 - gt) * torch.log(1 - score + 1e-10)).masked_fill(mask == 0, 0).sum(dim=-1).sum(dim=-1)
        # The softmax target sums to 1, so scale it by the positive count.
        supervised_loss = (pos_loss * num_pos + neg_loss) / num_pix
        return supervised_loss + 5e-1 * ps_loss

    def forward(self, output, full=None):
        """Compute the loss from the model ``output`` dict.

        Args:
            output: model outputs; the required keys depend on the mode (see
                the class docstring).
            full: unused; accepted for interface compatibility.

        Returns:
            ``(loss, loss_dict, None)``. ``loss`` is the mean over supervised
            scales, plus the contrastive term in the non-full mode.
            ``loss_dict`` maps ``'loss<k>'`` (and ``'contrast_loss'``) to
            Python floats for logging.

        Raises:
            ValueError: in full mode, if ``output['ali_score_map']`` is empty.
        """
        final_loss = 0.0
        loss_dict = {}
        cnt = 0

        if self.only_full:
            for k, v in output['ali_score_map'].items():
                loss = self.cal_loss(v.squeeze(1), output['fine_gt_mask'][k], output['mask'][k], None).mean()
                final_loss += loss
                cnt += 1
                loss_dict['loss{}'.format(k)] = loss.item()
            if cnt == 0:
                raise ValueError("output['ali_score_map'] is empty; nothing to supervise")
            return final_loss / cnt, loss_dict, None

        for k, v in output['ali_score_map'].items():
            # Only the 256 scale is supervised in this mode.
            if k != 256:
                continue
            loss = self.cal_loss2(v.squeeze(1), output['coarse_gt_mask'][k],
                                  output['mask'][k], output['fg_score_map'].squeeze(1)).mean()
            final_loss += loss
            cnt += 1
            loss_dict['loss{}'.format(k)] = loss.item()

        contrast_loss = -torch.log(output['contrast_score'] + 1e-10).mean()
        final_loss += 1.0 * contrast_loss
        loss_dict['contrast_loss'] = contrast_loss.item()
        # Without a 256-scale map, return the contrastive term alone; dividing
        # by zero would produce inf.
        return final_loss / max(cnt, 1), loss_dict, None
